import torch
import torch.nn as nn
import cv2 as cv
import numpy as np


def letter_box(img, new_shape):
    """Fit ``img`` into a ``new_shape`` x ``new_shape`` square, keeping its aspect ratio.

    The image is scaled by the largest factor that still fits both sides inside
    the square. The leftover area is filled with white (255, 255, 255).

    :param img: image array whose first two dimensions are (height, width)
    :param new_shape: side length, in pixels, of the square output
    :return: the resized and padded image
    """
    height, width = img.shape[:2]
    scale = min(new_shape / height, new_shape / width)
    target_w = int(round(width * scale))
    target_h = int(round(height * scale))
    resized = cv.resize(img, (target_w, target_h), interpolation=cv.INTER_LINEAR)

    # Split the leftover space evenly between the two sides. When it is odd, the
    # -0.1 / +0.1 nudges make round() go in opposite directions, so the extra
    # pixel lands on the bottom/right side.
    half_pad_w = (new_shape - target_w) / 2
    half_pad_h = (new_shape - target_h) / 2
    top = int(round(half_pad_h - 0.1))
    bottom = int(round(half_pad_h + 0.1))
    left = int(round(half_pad_w - 0.1))
    right = int(round(half_pad_w + 0.1))
    return cv.copyMakeBorder(resized, top, bottom, left, right,
                             cv.BORDER_CONSTANT, value=(255, 255, 255))


def letter_box_rectangle(img):
    """Stretch ``img`` to 512x512 and add white borders on the left and right.

    The aspect ratio is NOT preserved by the resize. A 64-pixel white
    (255, 255, 255) border is then added on each side, so the result is
    512 pixels high and 640 pixels wide.

    :param img: input image array
    :return: the stretched and padded image
    """
    side_pad = 64
    stretched = cv.resize(img, (512, 512), interpolation=cv.INTER_LINEAR)
    return cv.copyMakeBorder(stretched, 0, 0, side_pad, side_pad,
                             cv.BORDER_CONSTANT, value=(255, 255, 255))


def freeze_param(model, exclude='none'):
    """Freeze network parameters.

    :param model: network model. For ``exclude='bbr'`` it must expose a
        ``backbone`` sub-module.
    :param exclude: which part stays trainable, 'classify' or 'bbr'.
        'bbr' sets ``requires_grad = False`` on every parameter of
        ``model.backbone``. 'classify' currently changes nothing. Any other
        value, including the default 'none', only prints a warning.
    """
    if exclude not in ('classify', 'bbr'):
        print('Wrong exclude param')
        return
    # NOTE(review): 'classify' is a no-op here; confirm whether some part of
    # the model was meant to be frozen in that mode.
    if exclude == 'bbr':
        for p in model.backbone.parameters():
            p.requires_grad = False


def param_disturb(model, var=1e-5, device=torch.device('cpu')):
    """Add Gaussian noise, in place, to the weight and bias of every ``nn.Conv3d`` in ``model``.

    :param model: network whose Conv3d layers are perturbed. Other module types
        are left untouched.
    :param var: multiplier applied to standard-normal noise. Despite the name,
        this is the noise *standard deviation*, not its variance.
    :param device: kept only for backward compatibility. Noise is generated on
        each parameter's own device and with its dtype, so a model that is not
        on this device no longer fails.
    """
    with torch.no_grad():
        for module in model.modules():
            if not isinstance(module, nn.Conv3d):
                continue
            # randn_like allocates on the parameter's device/dtype, so the
            # in-place add cannot hit a device mismatch.
            module.weight.add_(var * torch.randn_like(module.weight))
            if module.bias is not None:
                module.bias.add_(var * torch.randn_like(module.bias))


def image_plot_conf(img, predict,
                    labels=('JumpingJack', 'SoccerJuggling', 'TaiChi', 'WallPushups', 'WritingOnBoard')):
    """Append a 30-pixel white strip below ``img`` and write the top prediction on it.

    The annotation is the predicted class name followed by its softmax
    probability, drawn in BGR red (0, 0, 255).

    :param img: image array of shape (H, W) or (H, W, C). Any width and height
        are accepted; previously only 320-wide, 3-channel images worked.
    :param predict: logits tensor of shape (1, num_classes). ``.item()`` below
        requires a batch size of 1.
    :param labels: class names indexed by class id. Defaults to the original
        five action classes.
    :return: uint8 image of height H + 30 carrying the annotation
    """
    strip_h = 30
    pred_idx = predict.argmax(dim=1).item()
    # Softmax is monotonic, so the probability at the logits' argmax is the
    # maximum probability.
    conf = nn.functional.softmax(predict, dim=1)[0, pred_idx].item()

    # Build the strip from the image's own width/channel layout instead of a
    # hard-coded 320x3, which made np.concatenate fail for other sizes.
    strip = np.full((strip_h,) + tuple(img.shape[1:]), 255.0)
    canvas = np.concatenate([img, strip], axis=0).astype(np.uint8)

    text = labels[pred_idx] + '  {:.4f}'.format(conf)
    # Put the baseline 20 px into the strip. This equals the old fixed y=260
    # for a 240-high image.
    baseline_y = img.shape[0] + 20
    return cv.putText(canvas, text, (0, baseline_y), cv.FONT_HERSHEY_PLAIN, 1.2, (0, 0, 255), 1)
